package io.parapet.messaging

import com.typesafe.scalalogging.StrictLogging
import io.parapet.messaging.Utils._
import io.parapet.messaging.ZmqAsyncServer.Worker
import io.parapet.messaging.api.MessagingApi._
import io.parapet.messaging.api.ServerAPI.Envelope
import io.parapet.messaging.api.{ErrorCodes, HeartbeatAPI}
import io.parapet.core.Dsl.DslF
import io.parapet.core.Event.{Start, Stop}
import io.parapet.core.{Encoder, Event, Process, ProcessRef}
import org.zeromq._

import scala.collection.mutable
import scala.util.Try

/**
  * Asynchronous ZeroMQ server based on the ROUTER socket type.
  *
  * The frontend ROUTER socket is bound to `address`; a DEALER backend socket is bound to
  * `inproc://backend`, and [[org.zeromq.ZMQ.proxy]] forwards messages between the two.
  * `numOfWorkers` workers are registered as children of this process and connect to the backend.
  *
  * A message consists of three zmq frames:
  * {{{
  *   0 - client identity. generated by ZMQ.ROUTER socket or sent by a client
  *   1 - request id
  *   2 - actual parapet event
  * }}}
  *
  * @param address      address the frontend ROUTER socket binds to
  * @param identity     identity set on the frontend socket; ignored when empty (or null)
  * @param service      process that receives the requests decoded by the workers
  * @param encoder      event encoder/decoder used by the workers
  * @param numOfWorkers number of workers to register; also passed to `ZContext` as its constructor argument
  * @tparam F an effect type
  */
class ZmqAsyncServer[F[_]](address: String,
                           identity: String,
                           service: ProcessRef,
                           encoder: Encoder,
                           numOfWorkers: Int) extends Process[F] with StrictLogging {

  import dsl._

  private lazy val zmqContext = new ZContext(numOfWorkers)
  private lazy val frontend = zmqContext.createSocket(SocketType.ROUTER)

  override def handle: Receive = {
    case Start =>
      createWorkers ++
        eval {
          if (identity != null && identity.nonEmpty) {
            // explicit charset so the socket identity does not depend on the JVM's default encoding
            frontend.setIdentity(identity.getBytes(java.nio.charset.StandardCharsets.UTF_8))
          }
          val backend = zmqContext.createSocket(SocketType.DEALER)
          frontend.bind(address)
          backend.bind("inproc://backend")
          // blocking call. can be interrupted (e.g. when the context is closed on Stop)
          Try(ZMQ.proxy(frontend, backend, null)).failed.foreach { err =>
            logger.debug(s"zmq proxy terminated: ${err.getMessage}")
          }
        }
    case Stop =>
      eval(Try(frontend.close())) ++ Utils.close(zmqContext)
  }

  /**
    * Registers `numOfWorkers` [[ZmqAsyncServer.Worker]] processes as children of this server.
    *
    * @return a flow that registers all workers, or `unit` when there are none
    */
  def createWorkers: DslF[F, Unit] = {
    (0 until numOfWorkers).map(i =>
      register(ref, Worker[F](i, ref, zmqContext, encoder, service))
    ).fold(unit)(_ ++ _)
  }
}


object ZmqAsyncServer {

  /**
    * Worker is strictly synchronous: after forwarding a request to the service it waits for the
    * service's response before it starts receiving further messages from the zmq socket.
    * Note: this class can be modified to support asynchronous dialog by using [[io.parapet.core.Channel]]
    * and [[io.parapet.core.Dsl.FlowOps.fork]] operator.
    *
    * @param id      worker unique id
    * @param parent  parent process, i.e. async server
    * @param ctx     ZMQ context
    * @param encoder event encoder
    * @param service the service to send requests received from socket
    * @tparam F an effect type
    */
  class Worker[F[_]](
                      id: Int,
                      parent: ProcessRef,
                      ctx: ZContext,
                      encoder: Encoder,
                      service: ProcessRef) extends Process[F] with StrictLogging {

    import dsl._

    override val ref: ProcessRef = ProcessRef(s"$parent-worker-$id")

    // request id to client identity, for requests forwarded to the service and not yet answered
    private val pendingRequests = mutable.Map[String, PendingRequest]()

    private lazy val socket = ctx.createSocket(SocketType.DEALER)

    // Builds a three-frame reply (client identity, request id, encoded event) and sends it on the socket.
    private def sendReply(clientId: ZFrame, requestId: String, event: Event): Unit = {
      val reply = new ZMsg
      reply.add(clientId)
      reply.add(requestId)
      reply.add(encoder.write(event))
      reply.send(socket, true)
    }

    override def handle: Receive = {
      case Start => eval {
        socket.connect("inproc://backend")
      } ++ Await ~> ref

      case Await =>
        // Option(_) additionally guards against a null message, which would otherwise fail on msg.pop()
        evalWith(Try(ZMsg.recvMsg(socket)).toOption.flatMap(Option(_))) {
          case Some(msg) =>
            val clientId = msg.pop() // can't be null
            val reqId = msg.popString()
            val data = msg.pop()
            if (reqId == null) {
              // keep receiving: without Await this worker would stop processing after one malformed message
              eval(logger.debug(s"request id has not been sent by client: ${clientId.toString}. Ignore request")) ++
                Await ~> ref
            } else {
              tryEval[F, Event](Try(encoder.read(data.getData)), {
                case Request(HeartbeatAPI.Ping) =>
                  eval(sendReply(clientId, reqId, Success(HeartbeatAPI.Pong))) ++ Await ~> ref
                case Request(event) =>
                  eval(pendingRequests.put(reqId, PendingRequest(clientId))) ++ Envelope(reqId, event) ~> service
                case other =>
                  // any decoded event that is not a Request would otherwise fail with a MatchError
                  eval {
                    logger.warn(s"Unexpected event from client ${clientId.toString}: $other")
                    sendReply(clientId, reqId, Failure(s"Server expected a Request but received: $other",
                      ErrorCodes.EncodingError))
                  } ++ Await ~> ref
              }, err => {
                eval(sendReply(clientId, reqId, Failure(s"Server failed to decode request. Error: ${err.getMessage}",
                  ErrorCodes.EncodingError))) ++ Await ~> ref
              })
            }

          case None => eval(logger.debug("zmq context was interrupted"))
        }

      //  process reply from service
      case Envelope(requestId, res) =>
        evalWith(pendingRequests.remove(requestId)) {
          case Some(PendingRequest(clientId)) => eval(sendReply(clientId, requestId, Success(res)))
          case None => eval(logger.error(s"Reply cannot be sent. unknown request id: $requestId"))
        } ++ Await ~> ref
    }
  }

  object Worker {
    /** Creates a [[Worker]]; see the class documentation for parameter meanings. */
    def apply[F[_]](
                     id: Int,
                     parent: ProcessRef,
                     ctx: ZContext,
                     encoder: Encoder,
                     service: ProcessRef): Process[F] = new Worker(id, parent, ctx, encoder, service)
  }

  // Self-sent by a worker to receive the next message from its socket.
  private object Await extends Event

  /**
    * Creates an asynchronous ZMQ server.
    *
    * @param address      address the frontend ROUTER socket binds to
    * @param service      process that receives decoded requests
    * @param encoder      event encoder/decoder
    * @param numOfWorkers number of workers (each handles one request at a time)
    * @param identity     frontend socket identity; empty means none is set
    */
  def apply[F[_]](address: String,
                  service: ProcessRef,
                  encoder: Encoder,
                  numOfWorkers: Int = 1,
                  identity: String = ""): Process[F] =
    new ZmqAsyncServer(address, identity, service, encoder, numOfWorkers)


  /** Identity frame of the client awaiting a reply for a request forwarded to the service. */
  case class PendingRequest(clientId: ZFrame)

}